{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/VincentStimper/normalizing-flows/blob/master/examples/real_nvp_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dZOHzj6Zf3bh"
   },
   "source": [
    "# Illustration of the Usage of the `normflows` Package\n",
    "## Training a Real NVP model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DWKHE21nf8ue"
   },
   "source": [
    "This notebook illustrates how to use the `normflows` package by fitting a simple [Real NVP](https://arxiv.org/abs/1605.08803) model to a 2D distribution consisting of two half moons.\n",
    "\n",
    "Before we can start, we have to install the package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BUbS_OlPgXNb",
    "outputId": "b4ce77f5-bc0a-4e70-b1c1-3b70f0d98f97"
   },
   "outputs": [],
   "source": [
    "!pip install normflows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6kdLk_paf3bk"
   },
   "outputs": [],
   "source": [
    "# Import required packages\n",
    "import torch\n",
    "import numpy as np\n",
    "import normflows as nf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OCPX0icbgrEA"
   },
   "source": [
    "After importing the required packages, we create a `nf.NormalizingFlow` model. To do so, we need a base distribution, which we set to be a Gaussian, and a list of flow layers. The flow layers are affine coupling layers; `nf.AffineCouplingBlock` already handles the splitting and merging of the features that coupling requires. We also swap the features after each layer to ensure that all of them are modified."
   ]
  },
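  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what one affine coupling layer computes, assume the standard Real NVP formulation and write the two features as $x_1$ and $x_2$. The names $s$ and $t$ for the two outputs of the parameter network are notation chosen here, and the exact parametrization inside `normflows` may differ in detail:\n",
    "\n",
    "$$\n",
    "y_1 = x_1, \\qquad y_2 = x_2 \\, \\exp\\big(s(x_1)\\big) + t(x_1), \\qquad \\log \\left| \\det \\frac{\\partial y}{\\partial x} \\right| = s(x_1)\n",
    "$$\n",
    "\n",
    "Only $x_2$ changes in a single layer, which is why the features are swapped before the next one. This also matches the shape of the parameter network used in this notebook: one input ($x_1$) and two outputs ($s$ and $t$)."
   ]
  },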
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ta_0PfGqf3bm",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Set up model\n",
    "\n",
    "# Define 2D Gaussian base distribution\n",
    "base = nf.distributions.base.DiagGaussian(2)\n",
    "\n",
    "# Define list of flows\n",
    "num_layers = 32\n",
    "flows = []\n",
    "for i in range(num_layers):\n",
    "    # Neural network with two hidden layers of 64 units each\n",
    "    # The last layer is initialized with zeros, which makes training more stable\n",
    "    param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)\n",
    "    # Add flow layer\n",
    "    flows.append(nf.flows.AffineCouplingBlock(param_map))\n",
    "    # Swap dimensions\n",
    "    flows.append(nf.flows.Permute(2, mode='swap'))\n",
    "    \n",
    "# Construct flow model\n",
    "model = nf.NormalizingFlow(base, flows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zxKwEBUhf3bm"
   },
   "outputs": [],
   "source": [
    "# Move model to GPU if available\n",
    "enable_cuda = True\n",
    "device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7z191AKvh1KO"
   },
   "source": [
    "This is our target distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define target distribution\n",
    "target = nf.distributions.TwoMoons()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 867
    },
    "id": "5hOhv--5f3bn",
    "outputId": "c9e72ca6-53f2-4929-8800-1b0ee7077eac",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Plot target distribution\n",
    "grid_size = 200\n",
    "xx, yy = torch.meshgrid(torch.linspace(-3, 3, grid_size), torch.linspace(-3, 3, grid_size), indexing='ij')\n",
    "zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)\n",
    "zz = zz.to(device)\n",
    "\n",
    "log_prob = target.log_prob(zz).to('cpu').view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')\n",
    "plt.gca().set_aspect('equal', 'box')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 902
    },
    "id": "K1umY7b2f3bo",
    "outputId": "03363ac7-9b22-4056-f6f1-e9b7a9041e7e"
   },
   "outputs": [],
   "source": [
    "# Plot initial flow distribution\n",
    "model.eval()\n",
    "log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)\n",
    "model.train()\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "plt.figure(figsize=(15, 15))\n",
    "plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')\n",
    "plt.gca().set_aspect('equal', 'box')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HXkqF_W0h6FT"
   },
   "source": [
    "Now we are ready to train the flow model. This works in the same way as training a standard neural network. Since we use samples from the target for training, we use the forward KL divergence as the objective, which is equivalent to maximum likelihood."
   ]
  },
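  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the forward KL divergence amounts to maximum likelihood, write $p$ for the target and $q_\\theta$ for the flow density (symbols chosen here for illustration):\n",
    "\n",
    "$$\n",
    "\\mathrm{KL}(p \\,\\|\\, q_\\theta) = \\mathbb{E}_{x \\sim p}[\\log p(x)] - \\mathbb{E}_{x \\sim p}[\\log q_\\theta(x)]\n",
    "$$\n",
    "\n",
    "The first term does not depend on $\\theta$, so minimizing the divergence is the same as maximizing the expected log-likelihood. With $N$ samples $x_n$ drawn from the target, a Monte Carlo estimate of the loss, up to that constant, is\n",
    "\n",
    "$$\n",
    "\\mathcal{L}(\\theta) \\approx -\\frac{1}{N} \\sum_{n=1}^{N} \\log q_\\theta(x_n)\n",
    "$$"
   ]
  },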
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "WC8o3MdCf3bp",
    "outputId": "2cf1a2f0-0833-4e10-9960-aac79087bce5",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train model\n",
    "max_iter = 4000\n",
    "num_samples = 2 ** 9\n",
    "show_iter = 500\n",
    "\n",
    "\n",
    "loss_hist = np.array([])\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)\n",
    "\n",
    "for it in tqdm(range(max_iter)):\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # Get training samples\n",
    "    x = target.sample(num_samples).to(device)\n",
    "    \n",
    "    # Compute loss\n",
    "    loss = model.forward_kld(x)\n",
    "    \n",
    "    # Do backprop and optimizer step\n",
    "    if ~(torch.isnan(loss) | torch.isinf(loss)):\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    \n",
    "    # Log loss\n",
    "    loss_hist = np.append(loss_hist, loss.to('cpu').data.numpy())\n",
    "    \n",
    "    # Plot learned distribution\n",
    "    if (it + 1) % show_iter == 0:\n",
    "        model.eval()\n",
    "        log_prob = model.log_prob(zz)\n",
    "        model.train()\n",
    "        prob = torch.exp(log_prob.to('cpu').view(*xx.shape))\n",
    "        prob[torch.isnan(prob)] = 0\n",
    "\n",
    "        plt.figure(figsize=(15, 15))\n",
    "        plt.pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')\n",
    "        plt.gca().set_aspect('equal', 'box')\n",
    "        plt.show()\n",
    "\n",
    "# Plot loss\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.plot(loss_hist, label='loss')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ycFsV2y3kQnt"
   },
   "source": [
    "This is our trained flow model!\n",
    "\n",
    "Note that there might be a density filament connecting the two modes, which is due to an architectural limitation of normalizing flows, especially prominent in Real NVP. You can find out more about it in [this paper](https://proceedings.mlr.press/v151/stimper22a)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 867
    },
    "id": "G5ZcFoG-f3bq",
    "outputId": "3883a118-1183-49c8-8a7e-6808908b7045"
   },
   "outputs": [],
   "source": [
    "# Plot target distribution\n",
    "f, ax = plt.subplots(1, 2, sharey=True, figsize=(15, 7))\n",
    "\n",
    "log_prob = target.log_prob(zz).to('cpu').view(*xx.shape)\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "ax[0].pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')\n",
    "\n",
    "ax[0].set_aspect('equal', 'box')\n",
    "ax[0].set_axis_off()\n",
    "ax[0].set_title('Target', fontsize=24)\n",
    "\n",
    "# Plot learned distribution\n",
    "model.eval()\n",
    "log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)\n",
    "model.train()\n",
    "prob = torch.exp(log_prob)\n",
    "prob[torch.isnan(prob)] = 0\n",
    "\n",
    "ax[1].pcolormesh(xx, yy, prob.data.numpy(), cmap='coolwarm')\n",
    "\n",
    "ax[1].set_aspect('equal', 'box')\n",
    "ax[1].set_axis_off()\n",
    "ax[1].set_title('Real NVP', fontsize=24)\n",
    "\n",
    "plt.subplots_adjust(wspace=0.1)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "real_nvp_colab.ipynb",
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
